[X86] fix inefficient llvm.masked.gather mask generation#175385
[X86] fix inefficient llvm.masked.gather mask generation #175385 — folkertdev wants to merge 3 commits into llvm:main from
llvm.masked.gather mask generation #175385 · Conversation
|
@llvm/pr-subscribers-backend-x86 Author: Folkert de Vries (folkertdev) ChangesAn (incomplete) attempt at fixing #59789. The issue describes inefficient mask generation when using the portable masked gather intrinsic. I've replicated it here. https://godbolt.org/z/h7b7c5Tb1 The issue seems to be how the What this branch implements is to much earlier rewrite: Into That transformation is sufficient to prevent the scalarization of the mask vector construction. This does work in the case of the original issue, but I'm not really sure whether it is the right approach. It also runs into issues with avx512 code, I think it just kind of breaks the assumption that the mask is a boolean vector. So maybe this should actually be solved in Patch is 33.37 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/175385.diff 2 Files Affected:
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 40ea3cb76bae4..92d7944d65a45 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -57397,14 +57397,18 @@ static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
-static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
+static SDValue rebuildGatherScatter(SelectionDAG &DAG,
+ MaskedGatherScatterSDNode *GorS,
SDValue Index, SDValue Base, SDValue Scale,
- SelectionDAG &DAG) {
+ SDValue Mask = SDValue()) {
SDLoc DL(GorS);
+ if (!Mask.getNode())
+ Mask = GorS->getMask();
+
if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
- SDValue Ops[] = { Gather->getChain(), Gather->getPassThru(),
- Gather->getMask(), Base, Index, Scale } ;
+ SDValue Ops[] = {
+ Gather->getChain(), Gather->getPassThru(), Mask, Base, Index, Scale};
return DAG.getMaskedGather(Gather->getVTList(),
Gather->getMemoryVT(), DL, Ops,
Gather->getMemOperand(),
@@ -57412,8 +57416,8 @@ static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
Gather->getExtensionType());
}
auto *Scatter = cast<MaskedScatterSDNode>(GorS);
- SDValue Ops[] = { Scatter->getChain(), Scatter->getValue(),
- Scatter->getMask(), Base, Index, Scale };
+ SDValue Ops[] = {
+ Scatter->getChain(), Scatter->getValue(), Mask, Base, Index, Scale};
return DAG.getMaskedScatter(Scatter->getVTList(),
Scatter->getMemoryVT(), DL,
Ops, Scatter->getMemOperand(),
@@ -57460,7 +57464,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
Index.getOperand(0), NewShAmt);
SDValue NewScale =
DAG.getConstant(ScaleAmt * 2, DL, Scale.getValueType());
- return rebuildGatherScatter(GorS, NewIndex, Base, NewScale, DAG);
+ return rebuildGatherScatter(DAG, GorS, NewIndex, Base, NewScale);
}
}
}
@@ -57478,7 +57482,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
// a split.
if (SDValue TruncIndex =
DAG.FoldConstantArithmetic(ISD::TRUNCATE, DL, NewVT, Index))
- return rebuildGatherScatter(GorS, TruncIndex, Base, Scale, DAG);
+ return rebuildGatherScatter(DAG, GorS, TruncIndex, Base, Scale);
// Shrink any sign/zero extends from 32 or smaller to larger than 32 if
// there are sufficient sign bits. Only do this before legalize types to
@@ -57487,13 +57491,13 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
Index.getOpcode() == ISD::ZERO_EXTEND) &&
Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
- return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+ return rebuildGatherScatter(DAG, GorS, Index, Base, Scale);
}
// Shrink if we remove an illegal type.
if (!TLI.isTypeLegal(Index.getValueType()) && TLI.isTypeLegal(NewVT)) {
Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
- return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+ return rebuildGatherScatter(DAG, GorS, Index, Base, Scale);
}
}
}
@@ -57518,13 +57522,13 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
DAG.getConstant(Adder, DL, PtrVT));
SDValue NewIndex = Index.getOperand(1 - I);
- return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+ return rebuildGatherScatter(DAG, GorS, NewIndex, NewBase, Scale);
}
// For non-constant cases, limit this to non-scaled cases.
if (ScaleAmt == 1) {
SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base, Splat);
SDValue NewIndex = Index.getOperand(1 - I);
- return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+ return rebuildGatherScatter(DAG, GorS, NewIndex, NewBase, Scale);
}
}
}
@@ -57539,7 +57543,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
SDValue NewIndex = DAG.getNode(ISD::ADD, DL, IndexVT,
Index.getOperand(1 - I), Splat);
SDValue NewBase = DAG.getConstant(0, DL, PtrVT);
- return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+ return rebuildGatherScatter(DAG, GorS, NewIndex, NewBase, Scale);
}
}
}
@@ -57550,12 +57554,69 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
MVT EltVT = IndexWidth > 32 ? MVT::i64 : MVT::i32;
IndexVT = IndexVT.changeVectorElementType(*DAG.getContext(), EltVT);
Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
- return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+ return rebuildGatherScatter(DAG, GorS, Index, Base, Scale);
}
}
// With vector masks we only demand the upper bit of the mask.
SDValue Mask = GorS->getMask();
+
+ // Replace a mask that looks like:
+ //
+ // t9: v4i1 = bitcast t8
+ //
+ // With one that looks like:
+ //
+ // t25: i32 = zero_extend t8
+ // t26: v4i32 = X86ISD::VBROADCAST t25
+ // t32: v4i32 = and t26, t31
+ // t33: v4i32 = X86ISD::PCMPEQ t32, t31
+ //
+ // The default expansion from an integer to a mask vector generates a lot more
+ // instructions.
+ if (DCI.isBeforeLegalize()) {
+ EVT MaskVT = Mask.getValueType();
+
+ if (MaskVT.isVector() && MaskVT.getVectorElementType() == MVT::i1 &&
+ Mask.getOpcode() == ISD::BITCAST) {
+
+ SDValue Bits = Mask.getOperand(0);
+ if (Bits.getValueType().isScalarInteger()) {
+ unsigned NumElts = MaskVT.getVectorNumElements();
+ if (NumElts == 4 || NumElts == 8) {
+
+ EVT ValueVT = N->getValueType(0);
+ EVT IntMaskVT = ValueVT.changeVectorElementTypeToInteger();
+ if (!IntMaskVT.isSimple() || !TLI.isTypeLegal(IntMaskVT))
+ return SDValue();
+
+ MVT MaskVecVT = IntMaskVT.getSimpleVT();
+ MVT MaskEltVT = MaskVecVT.getVectorElementType();
+
+ if (MaskVecVT.getVectorNumElements() != NumElts)
+ return SDValue();
+
+ SDValue BitsElt = DAG.getZExtOrTrunc(Bits, DL, MaskEltVT);
+ SDValue Bc = DAG.getNode(X86ISD::VBROADCAST, DL, MaskVecVT, BitsElt);
+
+ SmallVector<SDValue, 8> Lanes;
+ Lanes.reserve(NumElts);
+ for (unsigned i = 0; i < NumElts; ++i) {
+ uint64_t Bit = 1ull << i;
+ Lanes.push_back(DAG.getConstant(Bit, DL, MaskEltVT));
+ }
+
+ SDValue LaneBits = DAG.getBuildVector(MaskVecVT, DL, Lanes);
+ SDValue And = DAG.getNode(ISD::AND, DL, MaskVecVT, Bc, LaneBits);
+ SDValue NewMask =
+ DAG.getNode(X86ISD::PCMPEQ, DL, MaskVecVT, And, LaneBits);
+
+ return rebuildGatherScatter(DAG, GorS, Index, Base, Scale, NewMask);
+ }
+ }
+ }
+ }
+
if (Mask.getScalarValueSizeInBits() != 1) {
APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));
if (TLI.SimplifyDemandedBits(Mask, DemandedMask, DCI)) {
diff --git a/llvm/test/CodeGen/X86/masked_gather_scatter_portable.ll b/llvm/test/CodeGen/X86/masked_gather_scatter_portable.ll
new file mode 100644
index 0000000000000..016137ed7cc86
--- /dev/null
+++ b/llvm/test/CodeGen/X86/masked_gather_scatter_portable.ll
@@ -0,0 +1,600 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -mtriple=x86_64-unknown-unknown -O3 -mattr=+avx2 -mcpu=skylake < %s | FileCheck %s --check-prefix=AVX2
+
+define <4 x i32> @gather_avx_dd_128(<4 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_avx_dd_128:
+; AVX2: # %bb.0:
+; AVX2-NEXT: vmovaps %xmm0, %xmm1
+; AVX2-NEXT: vmovd %edi, %xmm0
+; AVX2-NEXT: vpbroadcastd %xmm0, %xmm0
+; AVX2-NEXT: vmovdqa {{.*#+}} xmm2 = [1,2,4,8]
+; AVX2-NEXT: vpand %xmm2, %xmm0, %xmm0
+; AVX2-NEXT: vpcmpeqd %xmm2, %xmm0, %xmm2
+; AVX2-NEXT: vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT: movq %rsi, %rdi
+; AVX2-NEXT: movl $4, %esi
+; AVX2-NEXT: jmp llvm.x86.avx2.gather.d.d.128@PLT # TAILCALL
+ %m4 = trunc i8 %maskbits to i4
+ %m = bitcast i4 %m4 to <4 x i1>
+ %m32 = sext <4 x i1> %m to <4 x i32>
+ %res = tail call <4 x i32> @llvm.x86.avx2.gather.d.d.128(<4 x i32> zeroinitializer, ptr %data, <4 x i32> %indices, <4 x i32> %m32, i8 4)
+ ret <4 x i32> %res
+}
+
+define <4 x i32> @gather_portable_dd_128(<4 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_portable_dd_128:
+; AVX2: # %bb.0:
+; AVX2-NEXT: vpmovzxdq {{.*#+}} ymm1 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
+; AVX2-NEXT: vmovd %edi, %xmm0
+; AVX2-NEXT: vpbroadcastd %xmm0, %xmm0
+; AVX2-NEXT: vmovdqa {{.*#+}} xmm2 = [1,2,4,8]
+; AVX2-NEXT: vpand %xmm2, %xmm0, %xmm0
+; AVX2-NEXT: vpcmpeqd %xmm2, %xmm0, %xmm2
+; AVX2-NEXT: vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT: vpgatherqd %xmm2, (%rsi,%ymm1,4), %xmm0
+; AVX2-NEXT: vzeroupper
+; AVX2-NEXT: retq
+ %m4 = trunc i8 %maskbits to i4
+ %m = bitcast i4 %m4 to <4 x i1>
+ %idx64 = zext <4 x i32> %indices to <4 x i64>
+ %ptrs = getelementptr i32, ptr %data, <4 x i64> %idx64
+ %res = tail call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> %ptrs, i32 4, <4 x i1> %m, <4 x i32> zeroinitializer)
+ ret <4 x i32> %res
+}
+
+define <8 x i32> @gather_avx_dd_256(<8 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_avx_dd_256:
+; AVX2: # %bb.0:
+; AVX2-NEXT: vmovd %edi, %xmm1
+; AVX2-NEXT: vpbroadcastb %xmm1, %ymm1
+; AVX2-NEXT: vmovdqa {{.*#+}} ymm2 = [1,2,4,8,16,32,64,128]
+; AVX2-NEXT: vpand %ymm2, %ymm1, %ymm1
+; AVX2-NEXT: vpcmpeqd %ymm2, %ymm1, %ymm2
+; AVX2-NEXT: vpxor %xmm1, %xmm1, %xmm1
+; AVX2-NEXT: vpgatherdd %ymm2, (%rsi,%ymm0,4), %ymm1
+; AVX2-NEXT: vmovdqa %ymm1, %ymm0
+; AVX2-NEXT: retq
+ %m = bitcast i8 %maskbits to <8 x i1>
+ %m32 = sext <8 x i1> %m to <8 x i32>
+ %res = tail call <8 x i32> @llvm.x86.avx2.gather.d.d.256(<8 x i32> zeroinitializer, ptr %data, <8 x i32> %indices, <8 x i32> %m32, i8 4)
+ ret <8 x i32> %res
+}
+
+define <8 x i32> @gather_portable_dd_256(<8 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_portable_dd_256:
+; AVX2: # %bb.0:
+; AVX2-NEXT: vpmovzxdq {{.*#+}} ymm1 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
+; AVX2-NEXT: vextracti128 $1, %ymm0, %xmm0
+; AVX2-NEXT: vpmovzxdq {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
+; AVX2-NEXT: vmovd %edi, %xmm2
+; AVX2-NEXT: vpbroadcastd %xmm2, %ymm2
+; AVX2-NEXT: vmovdqa {{.*#+}} ymm3 = [1,2,4,8,16,32,64,128]
+; AVX2-NEXT: vpand %ymm3, %ymm2, %ymm2
+; AVX2-NEXT: vpcmpeqd %ymm3, %ymm2, %ymm2
+; AVX2-NEXT: vextracti128 $1, %ymm2, %xmm3
+; AVX2-NEXT: vpxor %xmm4, %xmm4, %xmm4
+; AVX2-NEXT: vpxor %xmm5, %xmm5, %xmm5
+; AVX2-NEXT: vpgatherqd %xmm3, (%rsi,%ymm0,4), %xmm5
+; AVX2-NEXT: vpgatherqd %xmm2, (%rsi,%ymm1,4), %xmm4
+; AVX2-NEXT: vinserti128 $1, %xmm5, %ymm4, %ymm0
+; AVX2-NEXT: retq
+ %m = bitcast i8 %maskbits to <8 x i1>
+ %idx64 = zext <8 x i32> %indices to <8 x i64>
+ %ptrs = getelementptr i32, ptr %data, <8 x i64> %idx64
+ %res = tail call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> %ptrs, i32 4, <8 x i1> %m, <8 x i32> zeroinitializer)
+ ret <8 x i32> %res
+}
+
+define <2 x i32> @gather_avx_qd_128(<2 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_avx_qd_128:
+; AVX2: # %bb.0:
+; AVX2-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm0[0],zero,xmm0[1],zero
+; AVX2-NEXT: vmovd %edi, %xmm0
+; AVX2-NEXT: vpbroadcastd %xmm0, %xmm0
+; AVX2-NEXT: vpbroadcastq {{.*#+}} xmm2 = [1,2,1,2]
+; AVX2-NEXT: vpand %xmm2, %xmm0, %xmm0
+; AVX2-NEXT: vpcmpeqd %xmm2, %xmm0, %xmm2
+; AVX2-NEXT: vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT: movq %rsi, %rdi
+; AVX2-NEXT: movl $4, %esi
+; AVX2-NEXT: jmp llvm.x86.avx2.gather.q.d.128@PLT # TAILCALL
+ %m2 = trunc i8 %maskbits to i2
+ %m = bitcast i2 %m2 to <2 x i1>
+ %idx64 = zext <2 x i32> %indices to <2 x i64>
+ %m32 = sext <2 x i1> %m to <2 x i32>
+ %res = tail call <2 x i32> @llvm.x86.avx2.gather.q.d.128(<2 x i32> zeroinitializer, ptr %data, <2 x i64> %idx64, <2 x i32> %m32, i8 4)
+ ret <2 x i32> %res
+}
+
+define <2 x i32> @gather_portable_qd_128(<2 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_portable_qd_128:
+; AVX2: # %bb.0:
+; AVX2-NEXT: movl %edi, %eax
+; AVX2-NEXT: andb $2, %al
+; AVX2-NEXT: shrb %al
+; AVX2-NEXT: andb $1, %dil
+; AVX2-NEXT: vmovd %edi, %xmm1
+; AVX2-NEXT: vpinsrb $8, %eax, %xmm1, %xmm1
+; AVX2-NEXT: vpmovzxdq {{.*#+}} xmm2 = xmm0[0],zero,xmm0[1],zero
+; AVX2-NEXT: vpshufd {{.*#+}} xmm0 = xmm1[0,2,2,3]
+; AVX2-NEXT: vpslld $31, %xmm0, %xmm1
+; AVX2-NEXT: vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT: vpgatherqd %xmm1, (%rsi,%xmm2,4), %xmm0
+; AVX2-NEXT: retq
+ %m2 = trunc i8 %maskbits to i2
+ %m = bitcast i2 %m2 to <2 x i1>
+ %idx64 = zext <2 x i32> %indices to <2 x i64>
+ %ptrs = getelementptr i32, ptr %data, <2 x i64> %idx64
+ %res = tail call <2 x i32> @llvm.masked.gather.v2i32.v2p0(<2 x ptr> %ptrs, i32 4, <2 x i1> %m, <2 x i32> zeroinitializer)
+ ret <2 x i32> %res
+}
+
+define <4 x i32> @gather_avx_qd_256(<4 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_avx_qd_256:
+; AVX2: # %bb.0:
+; AVX2-NEXT: vpmovzxdq {{.*#+}} ymm1 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
+; AVX2-NEXT: vmovd %edi, %xmm0
+; AVX2-NEXT: vpbroadcastd %xmm0, %xmm0
+; AVX2-NEXT: vmovdqa {{.*#+}} xmm2 = [1,2,4,8]
+; AVX2-NEXT: vpand %xmm2, %xmm0, %xmm0
+; AVX2-NEXT: vpcmpeqd %xmm2, %xmm0, %xmm2
+; AVX2-NEXT: vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT: vpgatherqd %xmm2, (%rsi,%ymm1,4), %xmm0
+; AVX2-NEXT: vzeroupper
+; AVX2-NEXT: retq
+ %m4 = trunc i8 %maskbits to i4
+ %m = bitcast i4 %m4 to <4 x i1>
+ %idx64 = zext <4 x i32> %indices to <4 x i64>
+ %m32 = sext <4 x i1> %m to <4 x i32>
+ %res = tail call <4 x i32> @llvm.x86.avx2.gather.q.d.256(<4 x i32> zeroinitializer, ptr %data, <4 x i64> %idx64, <4 x i32> %m32, i8 4)
+ ret <4 x i32> %res
+}
+
+define <4 x i32> @gather_portable_qd_256(<4 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_portable_qd_256:
+; AVX2: # %bb.0:
+; AVX2-NEXT: vpmovzxdq {{.*#+}} ymm1 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
+; AVX2-NEXT: vmovd %edi, %xmm0
+; AVX2-NEXT: vpbroadcastd %xmm0, %xmm0
+; AVX2-NEXT: vmovdqa {{.*#+}} xmm2 = [1,2,4,8]
+; AVX2-NEXT: vpand %xmm2, %xmm0, %xmm0
+; AVX2-NEXT: vpcmpeqd %xmm2, %xmm0, %xmm2
+; AVX2-NEXT: vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT: vpgatherqd %xmm2, (%rsi,%ymm1,4), %xmm0
+; AVX2-NEXT: vzeroupper
+; AVX2-NEXT: retq
+ %m4 = trunc i8 %maskbits to i4
+ %m = bitcast i4 %m4 to <4 x i1>
+ %idx64 = zext <4 x i32> %indices to <4 x i64>
+ %ptrs = getelementptr i32, ptr %data, <4 x i64> %idx64
+ %res = tail call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> %ptrs, i32 4, <4 x i1> %m, <4 x i32> zeroinitializer)
+ ret <4 x i32> %res
+}
+
+define <2 x i64> @gather_avx_dq_128(<2 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_avx_dq_128:
+; AVX2: # %bb.0:
+; AVX2-NEXT: vmovaps %xmm0, %xmm1
+; AVX2-NEXT: vmovd %edi, %xmm0
+; AVX2-NEXT: vpbroadcastd %xmm0, %xmm0
+; AVX2-NEXT: vmovdqa {{.*#+}} xmm2 = [1,2]
+; AVX2-NEXT: vpand %xmm2, %xmm0, %xmm0
+; AVX2-NEXT: vpcmpeqq %xmm2, %xmm0, %xmm2
+; AVX2-NEXT: vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT: movq %rsi, %rdi
+; AVX2-NEXT: movl $8, %esi
+; AVX2-NEXT: jmp llvm.x86.avx2.gather.d.q.128@PLT # TAILCALL
+ %m2 = trunc i8 %maskbits to i2
+ %m = bitcast i2 %m2 to <2 x i1>
+ %m64 = sext <2 x i1> %m to <2 x i64>
+ %res = tail call <2 x i64> @llvm.x86.avx2.gather.d.q.128(<2 x i64> zeroinitializer, ptr %data, <2 x i32> %indices, <2 x i64> %m64, i8 8)
+ ret <2 x i64> %res
+}
+
+define <2 x i64> @gather_portable_dq_128(<2 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_portable_dq_128:
+; AVX2: # %bb.0:
+; AVX2-NEXT: movl %edi, %eax
+; AVX2-NEXT: andl $1, %eax
+; AVX2-NEXT: negq %rax
+; AVX2-NEXT: vmovq %rax, %xmm1
+; AVX2-NEXT: andb $2, %dil
+; AVX2-NEXT: shrb %dil
+; AVX2-NEXT: movzbl %dil, %eax
+; AVX2-NEXT: negq %rax
+; AVX2-NEXT: vmovq %rax, %xmm2
+; AVX2-NEXT: vpunpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm2[0]
+; AVX2-NEXT: vpmovzxdq {{.*#+}} xmm2 = xmm0[0],zero,xmm0[1],zero
+; AVX2-NEXT: vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT: vpgatherqq %xmm1, (%rsi,%xmm2,8), %xmm0
+; AVX2-NEXT: retq
+ %m2 = trunc i8 %maskbits to i2
+ %m = bitcast i2 %m2 to <2 x i1>
+ %idx64 = zext <2 x i32> %indices to <2 x i64>
+ %ptrs = getelementptr i64, ptr %data, <2 x i64> %idx64
+ %res = tail call <2 x i64> @llvm.masked.gather.v2i64.v2p0(<2 x ptr> %ptrs, i32 8, <2 x i1> %m, <2 x i64> zeroinitializer)
+ ret <2 x i64> %res
+}
+
+define <4 x i64> @gather_avx_dq_256(<4 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_avx_dq_256:
+; AVX2: # %bb.0:
+; AVX2-NEXT: vmovd %edi, %xmm1
+; AVX2-NEXT: vpbroadcastd %xmm1, %ymm1
+; AVX2-NEXT: vmovdqa {{.*#+}} ymm2 = [1,2,4,8]
+; AVX2-NEXT: vpand %ymm2, %ymm1, %ymm1
+; AVX2-NEXT: vpcmpeqq %ymm2, %ymm1, %ymm2
+; AVX2-NEXT: vpxor %xmm1, %xmm1, %xmm1
+; AVX2-NEXT: vpgatherdq %ymm2, (%rsi,%xmm0,8), %ymm1
+; AVX2-NEXT: vmovdqa %ymm1, %ymm0
+; AVX2-NEXT: retq
+ %m4 = trunc i8 %maskbits to i4
+ %m = bitcast i4 %m4 to <4 x i1>
+ %m64 = sext <4 x i1> %m to <4 x i64>
+ %res = tail call <4 x i64> @llvm.x86.avx2.gather.d.q.256(<4 x i64> zeroinitializer, ptr %data, <4 x i32> %indices, <4 x i64> %m64, i8 8)
+ ret <4 x i64> %res
+}
+
+define <4 x i64> @gather_portable_dq_256(<4 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_portable_dq_256:
+; AVX2: # %bb.0:
+; AVX2-NEXT: vpmovzxdq {{.*#+}} ymm1 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
+; AVX2-NEXT: vmovd %edi, %xmm0
+; AVX2-NEXT: vpbroadcastd %xmm0, %ymm0
+; AVX2-NEXT: vmovdqa {{.*#+}} ymm2 = [1,2,4,8]
+; AVX2-NEXT: vpand %ymm2, %ymm0, %ymm0
+; AVX2-NEXT: vpcmpeqq %ymm2, %ymm0, %ymm2
+; AVX2-NEXT: vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT: vpgatherqq %ymm2, (%rsi,%ymm1,8), %ymm0
+; AVX2-NEXT: retq
+ %m4 = trunc i8 %maskbits to i4
+ %m = bitcast i4 %m4 to <4 x i1>
+ %idx64 = zext <4 x i32> %indices to <4 x i64>
+ %ptrs = getelementptr i64, ptr %data, <4 x i64> %idx64
+ %res = tail call <4 x i64> @llvm.masked.gather.v4i64.v4p0(<4 x ptr> %ptrs, i32 8, <4 x i1> %m, <4 x i64> zeroinitializer)
+ ret <4 x i64> %res
+}
+
+define <2 x i64> @gather_avx_qq_128(<2 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_avx_qq_128:
+; AVX2: # %bb.0:
+; AVX2-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm0[0],zero,xmm0[1],zero
+; AVX2-NEXT: vmovd %edi, %xmm0
+; AVX2-NEXT: vpbroadcastd %xmm0, %xmm0
+; AVX2-NEXT: vmovdqa {{.*#+}} xmm2 = [1,2]
+; AVX2-NEXT: vpand %xmm2, %xmm0, %xmm0
+; AVX2-NEXT: vpcmpeqq %xmm2, %xmm0, %xmm2
+; AVX2-NEXT: vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT: movq %rsi, %rdi
+; AVX2-NEXT: movl $8, %esi
+; AVX2-NEXT: jmp llvm.x86.avx2.gather.q.q.128@PLT # TAILCALL
+ %m2 = trunc i8 %maskbits to i2
+ %m = bitcast i2 %m2 to <2 x i1>
+ %idx64 = zext <2 x i32> %indices to <2 x i64>
+ %m64 = sext <2 x i1> %m to <2 x i6...
[truncated]
|
🐧 Linux x64 Test Results
✅ The build succeeded and all tests passed. |
🪟 Windows x64 Test Results
✅ The build succeeded and all tests passed. |
|
Hmm, doing it in |
RKSimon
left a comment
There was a problem hiding this comment.
Please investigate the CodeGen/X86/masked_gather_scatter.ll failures
|
I will, but could you provide some guidance on what the best place to tackle this is? in the combine, or in the lower? |
264d4aa to
44f95a1
Compare
|
So, a bit of a cop out, but the tests in that existing file all assume avx512. The inefficient mask generation is only a problem when avx512 is not available (it has those special mask registers). So the optimization in this PR now just does not run if avx512 is available. I also experimented further with performing the optimization in |
…gets Test coverage to help llvm#175385
…l vectors for masked load/store Test coverage to help llvm#175385
…nsion before it might get split by legalisation Masked load/store/gathers often need to bitcast the mask from a bitcasted integer. On pre-AVX512 targets this can lead to some rather nasty scalarization if we don't custom expand the mask first. This patch uses the combineToExtendBoolVectorInReg helper function to canonicalise the masks, similar to what we already do for vselect expansion. Alternative to llvm#175385 Fixes llvm#59789
|
@folkertdev I ended up coming up with #175769, which I think should address the issue reusing existing code. |
|
Nice, thanks for looking into it! |
Sorry for poaching it :( |
…or extension before it might get split by legalisation (#175769) Masked load/store/gathers often need to bitcast the mask from a bitcasted integer. On pre-AVX512 targets this can lead to some rather nasty scalarization if we don't custom expand the mask first. This patch uses the canonicalizeBoolMask /combineToExtendBoolVectorInReg helper functions to canonicalise the masks, similar to what we already do for vselect expansion. Alternative to #175385 Fixes #59789
… for "fast-gather" avx2 targets (llvm#175736) Test coverage to help llvm#175385
…l vectors for masked load/store (llvm#175746) Test coverage to help llvm#175385
…or extension before it might get split by legalisation (llvm#175769) Masked load/store/gathers often need to bitcast the mask from a bitcasted integer. On pre-AVX512 targets this can lead to some rather nasty scalarization if we don't custom expand the mask first. This patch uses the canonicalizeBoolMask /combineToExtendBoolVectorInReg helper functions to canonicalise the masks, similar to what we already do for vselect expansion. Alternative to llvm#175385 Fixes llvm#59789
…or extension before it might get split by legalisation (llvm#175769) Masked load/store/gathers often need to bitcast the mask from a bitcasted integer. On pre-AVX512 targets this can lead to some rather nasty scalarization if we don't custom expand the mask first. This patch uses the canonicalizeBoolMask /combineToExtendBoolVectorInReg helper functions to canonicalise the masks, similar to what we already do for vselect expansion. Alternative to llvm#175385 Fixes llvm#59789
An (incomplete) attempt at fixing #59789.
The issue describes inefficient mask generation when using the portable masked gather intrinsic. I've replicated it here.
https://godbolt.org/z/h7b7c5Tb1
The issue seems to be how the
`maskbit` mask is converted into a vector. Based on the logs, that ultimately happens when `masked_gather` is lowered. What this branch implements is to rewrite, much earlier:
Into
That transformation is sufficient to prevent the scalarization of the mask vector construction. This does work in the case of the original issue, but I'm not really sure whether it is the right approach. It also runs into issues with avx512 code, I think it just kind of breaks the assumption that the mask is a boolean vector.
So maybe this should actually be solved in
`LowerMGATHER`? I think I'm just kind of missing something that would make this simpler and more robust.